{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn \n",
    "from glove import *\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from sklearn.manifold import TSNE\n",
    "import numpy as np\n",
    "from collections import Counter, defaultdict\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GloveModel(nn.Module):\n",
    "    \"\"\"GloVe scoring model: two embedding tables and two per-word bias tables.\n",
    "\n",
    "    Args:\n",
    "        num_embedding: Number of rows in each table.\n",
    "        embedding_dim: Width of the vectors stored in wi and wj.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_embedding, embedding_dim):\n",
    "        super(GloveModel, self).__init__()\n",
    "        self.wi = nn.Embedding(num_embedding, embedding_dim)\n",
    "        self.wj = nn.Embedding(num_embedding, embedding_dim)\n",
    "        self.bi = nn.Embedding(num_embedding, 1)\n",
    "        self.bj = nn.Embedding(num_embedding, 1)\n",
    "\n",
    "    def forward(self, i_indices, j_indices):\n",
    "        \"\"\"Score every (i, j) index pair as dot(w_i, w_j) + b_i + b_j.\n",
    "\n",
    "        Args:\n",
    "            i_indices: Integer tensor of row indices into wi / bi.\n",
    "            j_indices: Integer tensor of row indices into wj / bj, same shape.\n",
    "\n",
    "        Returns:\n",
    "            Tensor shaped like the index tensors, holding one score per pair.\n",
    "        \"\"\"\n",
    "        w_i = self.wi(i_indices)\n",
    "        w_j = self.wj(j_indices)\n",
    "        # Drop only the trailing size-1 bias axis so a batch of one keeps its batch axis.\n",
    "        b_i = self.bi(i_indices).squeeze(-1)\n",
    "        b_j = self.bj(j_indices).squeeze(-1)\n",
    "        # Reduce over the embedding axis only; summing every element would fold\n",
    "        # the whole batch into one scalar shared by all pairs.\n",
    "        return torch.sum(w_i * w_j, dim=-1) + b_i + b_j\n",
    "\n",
    "def weight_func(x, x_max, alpha):\n",
    "    \"\"\"Weighting function min((x / x_max) ** alpha, 1), applied element-wise.\n",
    "\n",
    "    Args:\n",
    "        x: Tensor of co-occurrence values.\n",
    "        x_max: Value at and above which the weight is capped at 1.\n",
    "        alpha: Exponent applied to the ratio x / x_max.\n",
    "\n",
    "    Returns:\n",
    "        Floating-point tensor of weights with the same shape as x.\n",
    "    \"\"\"\n",
    "    # Promote integer tensors first so the ratio is a true division on every\n",
    "    # PyTorch version (integer floor division would zero most weights).\n",
    "    if not torch.is_floating_point(x):\n",
    "        x = x.float()\n",
    "    return torch.clamp((x / x_max) ** alpha, max=1.0)\n",
    "\n",
    "#     return wx.cuda()\n",
    "\n",
    "def wmse_loss(weights, inputs, targets):\n",
    "    \"\"\"Return the mean of the element-wise squared error scaled by weights.\"\"\"\n",
    "    squared_error = F.mse_loss(inputs, targets, reduction='none')\n",
    "    weighted_error = weights * squared_error\n",
    "    return weighted_error.mean()\n",
    "#     return torch.mean(loss).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class GloveDataset:\n",
    "    \"\"\"Tokenised corpus together with its distance-weighted co-occurrence entries.\n",
    "\n",
    "    Args:\n",
    "        text: Corpus as a single string; tokens are separated by single spaces.\n",
    "        n_words: Maximum number of leading tokens to keep.\n",
    "        window_size: Number of neighbours considered on each side of a token.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, text, n_words=200000, window_size=5):\n",
    "        self._window_size = window_size\n",
    "        self._tokens = text.split(' ')[:n_words]\n",
    "        word_counter = Counter(self._tokens)\n",
    "        # Ids follow descending frequency: the most common token gets id 0.\n",
    "        self._word2id = {w: i for i, (w, _) in enumerate(word_counter.most_common())}\n",
    "        self._id2word = {i: w for w, i in self._word2id.items()}\n",
    "        self._vocab_len = len(self._word2id)\n",
    "        self._id_tokens = [self._word2id[w] for w in self._tokens]\n",
    "        self._create_coocurrence_matrix()\n",
    "\n",
    "        print('# of words: {}'.format(len(self._tokens)))\n",
    "        print('Vocabulary length: {}'.format(self._vocab_len))\n",
    "\n",
    "    def _create_coocurrence_matrix(self):\n",
    "        \"\"\"Fill _i_idx, _j_idx and _xij with the non-zero co-occurrence entries.\"\"\"\n",
    "        cooc_mat = defaultdict(Counter)\n",
    "        n_tokens = len(self._id_tokens)\n",
    "\n",
    "        for i, w in enumerate(self._id_tokens):\n",
    "            start_i = max(i - self._window_size, 0)\n",
    "            end_i = min(i + self._window_size + 1, n_tokens)\n",
    "\n",
    "            for j in range(start_i, end_i):\n",
    "                if i != j:\n",
    "                    c = self._id_tokens[j]\n",
    "                    # Closer neighbours contribute more: each hit adds 1 / distance.\n",
    "                    cooc_mat[w][c] += 1 / abs(j - i)\n",
    "\n",
    "        i_idx = []\n",
    "        j_idx = []\n",
    "        xij = []\n",
    "        for w, cnt in cooc_mat.items():\n",
    "            for c, v in cnt.items():\n",
    "                i_idx.append(w)\n",
    "                j_idx.append(c)\n",
    "                xij.append(v)\n",
    "\n",
    "        self._i_idx = torch.tensor(i_idx, dtype=torch.long)\n",
    "        self._j_idx = torch.tensor(j_idx, dtype=torch.long)\n",
    "        # The values are fractional sums of 1 / distance, so they must be stored\n",
    "        # as floats; an integer tensor would truncate them.\n",
    "        self._xij = torch.tensor(xij, dtype=torch.float32)\n",
    "\n",
    "    def get_batches(self, batch_size):\n",
    "        \"\"\"Yield (x_ij, i_idx, j_idx) mini-batches in a fresh random order.\n",
    "\n",
    "        Every entry is visited exactly once; the last batch may be smaller\n",
    "        than batch_size.\n",
    "        \"\"\"\n",
    "        n_entries = len(self._xij)\n",
    "        rand_ids = torch.as_tensor(np.random.permutation(n_entries), dtype=torch.long)\n",
    "        for p in range(0, n_entries, batch_size):\n",
    "            batch_ids = rand_ids[p:p + batch_size]\n",
    "            yield self._xij[batch_ids], self._i_idx[batch_ids], self._j_idx[batch_ids]\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# of words: 1000000\n",
      "Vocabulary length: 52755\n"
     ]
    }
   ],
   "source": [
    "# Read the corpus inside a context manager so the file handle is closed\n",
    "# as soon as the text has been loaded.\n",
    "with open(\"text.txt\") as corpus_file:\n",
    "    dataset = GloveDataset(corpus_file.read(), 1000000)\n",
    "\n",
    "# Use the first GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "EMBED_DIM = 300\n",
    "glove = GloveModel(dataset._vocab_len, EMBED_DIM)\n",
    "glove = glove.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1/1 \t Batch: 500/1724 \t Loss: 736.8017051875592\n",
      "Epoch: 1/1 \t Batch: 1000/1724 \t Loss: 205.43506680727006\n",
      "Epoch: 1/1 \t Batch: 1500/1724 \t Loss: 373.33878472447395\n",
      "Epoch loss: 390.70504263611014\n",
      "Saving model...\n"
     ]
    }
   ],
   "source": [
    "optimizer = optim.Adagrad(glove.parameters(), lr=0.001)\n",
    "N_EPOCHS = 1\n",
    "BATCH_SIZE = 2048\n",
    "X_MAX = 100\n",
    "ALPHA = 0.75\n",
    "# Ceiling division: get_batches also yields the final, possibly smaller, batch.\n",
    "n_batches = (len(dataset._xij) + BATCH_SIZE - 1) // BATCH_SIZE\n",
    "for e in range(1, N_EPOCHS+1):\n",
    "    # Losses of the current epoch only; reset at the start of every epoch.\n",
    "    loss_values = []\n",
    "    for batch_i, (x_ij, i_idx, j_idx) in enumerate(dataset.get_batches(BATCH_SIZE), start=1):\n",
    "        # Convert the co-occurrence values to float once so that both the\n",
    "        # weighting ratio and the log target are computed in floating point.\n",
    "        x_ij = x_ij.to(device).float()\n",
    "        i_idx = i_idx.to(device)\n",
    "        j_idx = j_idx.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = glove(i_idx, j_idx)\n",
    "        weights_x = weight_func(x_ij, X_MAX, ALPHA)\n",
    "        # Regression target is log(1 + x_ij).\n",
    "        loss = wmse_loss(weights_x, outputs, torch.log(x_ij + 1))\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss_values.append(loss.item())\n",
    "\n",
    "        # Progress line every 500 batches, averaged over the last 20 losses.\n",
    "        if batch_i % 500 == 0:\n",
    "            print(\"Epoch: {}/{} \\t Batch: {}/{} \\t Loss: {}\".format(e, N_EPOCHS, batch_i, n_batches, np.mean(loss_values[-20:])))\n",
    "    print('Epoch loss: {}'.format(np.mean(loss_values)))\n",
    "    print(\"Saving model...\")\n",
    "    torch.save(glove.state_dict(), \"text.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the i_idx tensor left over from the training loop.\n",
    "i_idx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
